Skip to content

Fix GPU Hang in Gemma4 and add metrics#40

Merged
solderzzc merged 3 commits into
mainfrom
fix/mtp-gpu-hang
May 12, 2026
Merged

Fix GPU Hang in Gemma4 and add metrics#40
solderzzc merged 3 commits into
mainfrom
fix/mtp-gpu-hang

Conversation

@solderzzc

Copy link
Copy Markdown
Member

Pushes the final GPU hang fix and MTP metric propagation that was missed before the main branch merge.

Copilot AI review requested due to automatic review settings May 12, 2026 21:21

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to (1) prevent a GPU hang in Gemma4’s sparse masked-embedder logit projection by removing a per-row scatter loop, and (2) propagate speculative/MTP draft-token metrics into the generation completion info so callers can observe acceptance rates.

Changes:

  • Adds acceptedDraftTokens / totalDraftTokens tracking to the common token-iterator interface and includes them in GenerateCompletionInfo emitted at the end of async generation.
  • Updates SpeculativeTokenIterator to accumulate accepted/total draft-token counts per speculation round.
  • Replaces a CPU-style per-row scatter loop in Gemma4Text.maskedEmbedderLogits with a vectorized advanced-index assignment.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
Libraries/MLXLMCommon/Evaluate.swift Adds draft-token metrics to iterators and propagates them into GenerateCompletionInfo for async generation.
Libraries/MLXLLM/Models/Gemma4Text.swift Vectorizes candidate-logit scattering into the output vocab tensor to address a GPU hang.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +1120 to +1121
let rowIndices = MLXArray(0 ..< Int32(B * S)).reshaped([B * S, 1])
output2D[rowIndices, scatterIdx2D] = selectedLogits2D

@solderzzc solderzzc left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch regarding Int32(B * S)! Using MLXArray.arange(B * S).asType(.int32) correctly bypasses the Swift 32-bit width overflow trap while keeping the MLX pipeline safe. Applied in the latest commit.

@solderzzc solderzzc merged commit 7c45487 into main May 12, 2026
6 checks passed
@solderzzc solderzzc deleted the fix/mtp-gpu-hang branch May 12, 2026 21:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants